'''
This script calculates the fidelity of state-detection measurements using likelihood functions.

For theoretical details see:
    - Sinhal, M.: Quantum control of single molecular ions. PhD thesis, University of Basel (2021), Sec. 5.4.2 (p. 142-145)
    - Supplementary Material of: Sinhal, Meir, Najafian, Hegi and Willitsch, Science 367(6483), 1213–1218 (2020)
'''

import numpy as np
import math
import matplotlib.pyplot as plt

### ---  Parameters for calculations --- ###
N = 9  # number of trials (experiments) N in the binomial model
p_alpha = 0.95  # per-trial success probability under hypothesis alpha (must exceed p_beta)
p_beta = 0.05  # per-trial success probability under hypothesis beta
k_values = np.arange(0, N+1)  # every possible success count k = 0, 1, ..., N


def _binomial_likelihood(N, k, p):
    """Return L(p | k, N) = C(N, k) * p**k * (1 - p)**(N - k).

    This is the binomial probability of observing k successes in N trials
    when each trial succeeds with probability p. Both public likelihood
    functions delegate here so the formula lives in one place.

    Raises:
        ValueError: if k is outside 0..N (the original factorial-based code
            also raised ValueError there, just with a less helpful message).
    """
    if not 0 <= k <= N:
        raise ValueError(f"k must satisfy 0 <= k <= N, got k={k}, N={N}")
    # math.comb computes the binomial coefficient exactly as an integer,
    # instead of dividing three (potentially huge) factorials.
    return math.comb(N, k) * (p ** k) * ((1 - p) ** (N - k))


def likelihood_p_beta(N, k, p_beta):
    """Return the likelihood L(p_beta | k, N) of k successes in N trials."""
    return _binomial_likelihood(N, k, p_beta)


def likelihood_p_alpha(N, k, p_alpha):
    """Return the likelihood L(p_alpha | k, N) of k successes in N trials."""
    return _binomial_likelihood(N, k, p_alpha)

def calculate_k_t(N, p_alpha, p_beta):
    """Return the integer decision threshold k_t for N trials.

    k_t is the floor of the crossing point k* at which
    L(p_alpha | k, N) == L(p_beta | k, N):

        k* = N / (ln(p_alpha / p_beta) / ln((1 - p_beta) / (1 - p_alpha)) + 1)

    For 0 < p_beta < p_alpha < 1, outcomes k <= k_t are at least as likely
    under p_beta, and outcomes k > k_t are strictly more likely under p_alpha.

    Raises:
        ValueError: unless 0 < p_beta < p_alpha < 1. Outside that range the
            formula above either divides by zero, yields NaN, or (for swapped
            probabilities) produces a threshold with the wrong orientation.
    """
    if not 0 < p_beta < p_alpha < 1:
        raise ValueError(
            f"require 0 < p_beta < p_alpha < 1, got p_alpha={p_alpha}, p_beta={p_beta}"
        )
    ratio = np.log(p_alpha / p_beta) / np.log((1 - p_beta) / (1 - p_alpha))
    k_t = np.floor(N / (ratio + 1))
    return int(k_t)

def calculate_epsilon_alpha_beta(N, p_alpha, p_beta, *, verbose=True):
    """Return (epsilon_alpha, epsilon_beta) for the threshold k_t.

    epsilon_alpha is the total p_alpha likelihood of outcomes k <= k_t, i.e.
    the probability under p_alpha that the outcome falls on the beta side of
    the threshold. epsilon_beta is the total p_beta likelihood of outcomes
    k > k_t, i.e. the probability under p_beta that it falls on the alpha side.

    Args:
        N: number of trials.
        p_alpha: per-trial success probability under hypothesis alpha.
        p_beta: per-trial success probability under hypothesis beta.
        verbose: keyword-only; when True (the default, matching the previous
            behaviour) print the k indices included in each sum.
    """
    k_t = calculate_k_t(N, p_alpha, p_beta)

    alpha_indices = range(0, k_t + 1)  # includes k_t itself
    beta_indices = range(k_t + 1, N + 1)

    epsilon_alpha = sum(likelihood_p_alpha(N, k, p_alpha) for k in alpha_indices)
    if verbose:
        print('For alpha summing over the indices:')
        # Plain loop: printing is a side effect, not something to collect.
        for k in alpha_indices:
            print(k)

    epsilon_beta = sum(likelihood_p_beta(N, k, p_beta) for k in beta_indices)
    if verbose:
        print('For beta summing over the indices:')
        for k in beta_indices:
            print(k)

    return epsilon_alpha, epsilon_beta


# Likelihood of every possible success count k (k successes out of N trials)
L_p_beta = [likelihood_p_beta(N, k, p_beta) for k in k_values]
L_p_alpha = [likelihood_p_alpha(N, k, p_alpha) for k in k_values]

# Decision threshold, drawn as a vertical line in the plot below
k_t = calculate_k_t(N, p_alpha, p_beta)

# Compute and print the error probabilities before plotting: plt.show() can
# block until the figure window is closed, which would otherwise hide these
# results until then.
epsilon_alpha, epsilon_beta = calculate_epsilon_alpha_beta(N, p_alpha, p_beta)
print("εα =", epsilon_alpha)
print("εβ =", epsilon_beta)

# Plotting
plt.plot(k_values, L_p_beta, 'o-', label="L(p_beta | k, N)")
plt.plot(k_values, L_p_alpha, 'o-', label="L(p_alpha | k, N)")
plt.axvline(x=k_t, color='red', linestyle='--', label=f'k_t = {k_t}')
# k counts successes; N (shown in the title) is the number of trials.
plt.xlabel("Number of successes, k")
plt.ylabel("Likelihood, L")
plt.legend()
plt.title(f"Likelihoods vs number of successes k\nN={N}, p_alpha={p_alpha}, p_beta={p_beta}")
plt.show()